Support Linear State in SDPA Pipeline #3359
apaniukov wants to merge 31 commits into openvinotoolkit:master from
Conversation
Pull request overview
This PR generalizes the “KV cache state” tracking to support fixed-size linear (and hybrid) cache state in stateful/SDPA-based pipelines by introducing a unified cache state type and propagating it through LLM/VLM/speculative decoding codepaths.
Changes:
- Replaced KVCacheState with CacheState across pipelines and embedders.
- Added cache kind detection (CacheTypes / get_cache_types) and updated cache-trimming behavior to reset for linear caches.
- Wired cache-kind awareness into speculative decoding wrappers and stateful LLM pipeline initialization.
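To make the unified abstraction concrete, here is a minimal sketch of what a CacheTypes / CacheState pair could look like. The type and method names (has_linear(), num_tokens_to_trim, reset_mem_state) follow the snippets quoted later in this review; the exact field layout is an assumption, not the actual openvino.genai implementation.

```cpp
#include <cstddef>

// Hypothetical sketch: which cache kinds a stateful model exposes.
// A classic KV cache has a dynamic sequence axis; a linear/SSM state
// is fully static in shape. A hybrid model has both.
struct CacheTypes {
    bool kv = false;      // classic KV cache detected
    bool linear = false;  // fixed-size linear (conv/SSM) state detected
    bool has_linear() const { return linear; }
    bool is_hybrid() const { return kv && linear; }
};

// Hypothetical sketch: per-pipeline cache bookkeeping that replaces the
// old KVCacheState, carrying the detected kinds alongside trim/reset flags.
struct CacheState {
    CacheTypes types;
    std::size_t num_tokens_to_trim = 0;
    bool reset_mem_state = false;
    bool has_linear() const { return types.has_linear(); }
};
```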
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| src/cpp/src/visual_language/vision_token_pruning_processor.hpp | Updates pruning processor API to use CacheState. |
| src/cpp/src/visual_language/vision_token_pruning_processor.cpp | Updates pruning processor implementation signature to CacheState. |
| src/cpp/src/visual_language/pipeline.cpp | VLM pipeline now uses CacheState when managing chat history/cache trimming. |
| src/cpp/src/visual_language/phi4mm/classes.cpp | Switches embedder history/cache bookkeeping to m_cache_state. |
| src/cpp/src/visual_language/phi3_vision/classes.cpp | Switches embedder history/cache bookkeeping to m_cache_state. |
| src/cpp/src/visual_language/inputs_embedder.hpp | Replaces stored state from KVCacheState to CacheState. |
| src/cpp/src/visual_language/inputs_embedder.cpp | Updates chat/history alignment and rollback bookkeeping to CacheState. |
| src/cpp/src/utils.hpp | Introduces CacheTypes, CacheState, and get_cache_types() API. |
| src/cpp/src/utils.cpp | Implements cache kind detection and updates trim_kv_cache() behavior for linear caches. |
| src/cpp/src/speculative_decoding/stateful/fast_draft_strategy.hpp | Adds CacheTypes member to infer wrapper. |
| src/cpp/src/speculative_decoding/stateful/fast_draft_strategy.cpp | Initializes CacheTypes and uses it to build CacheState for trimming. |
| src/cpp/src/speculative_decoding/stateful/eagle3_strategy.hpp | Adds CacheTypes member to eagle3 infer wrapper base. |
| src/cpp/src/speculative_decoding/stateful/eagle3_strategy.cpp | Initializes CacheTypes and uses it to build CacheState for trimming. |
| src/cpp/src/lm_encoding.hpp | Updates encoding helpers to accept CacheState. |
| src/cpp/src/lm_encoding.cpp | Updates chat-history alignment logic and cache-state updates for CacheState. |
| src/cpp/src/llm/pipeline_stateful.hpp | Renames stored cache reflection to m_cache_state and renames reset helper. |
| src/cpp/src/llm/pipeline_stateful.cpp | Initializes CacheState from model and propagates it through chat/trim logic. |
Comments suppressed due to low confidence (1)
src/cpp/src/utils.cpp:525
- trim_kv_cache() resets the InferRequest when reset_mem_state is set (or when linear cache needs reset), but it returns without clearing cache_state.reset_mem_state / num_tokens_to_trim or updating the token reflection state. This can leave CacheState inconsistent (stale tokens / repeated resets) for subsequent steps. Consider resetting the CacheState fields when a reset happens (and clearing the token reflection if the underlying model state is cleared).
```cpp
void trim_kv_cache(ov::InferRequest request, CacheState& cache_state, std::optional<AdapterController> adapter_controller) {
    if (
        cache_state.reset_mem_state
        // linear cache stores only the last state, trimming is not possible, so we reset the whole cache in this case
        || (cache_state.num_tokens_to_trim > 0 && cache_state.has_linear())
    ) {
        if (adapter_controller) {
            for (auto& state : request.query_state()) {
                if (!adapter_controller->has_state_name(state.get_name())) {
                    state.reset();
                }
            }
        } else {
            request.reset_state();
        }
        return;
    }
    // ...
}
```
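The bookkeeping fix the comment asks for could look like the stand-in below: clear the flags as part of the reset so a later step neither re-resets nor trims against stale state. CacheState here is a minimal emulation (no InferRequest), and maybe_reset is a hypothetical name for illustration.

```cpp
#include <cstddef>

// Minimal stand-in for the CacheState bookkeeping discussed above;
// field names follow the quoted snippet, the rest is an assumption.
struct CacheState {
    std::size_t num_tokens_to_trim = 0;
    bool reset_mem_state = false;
    bool linear = false;
    bool has_linear() const { return linear; }
};

// Returns true if a full reset was performed. After resetting the model
// state (emulated by the flag), the CacheState fields are cleared so the
// next step does not see stale trim requests or repeat the reset.
inline bool maybe_reset(CacheState& cache_state, bool& model_state_cleared) {
    if (cache_state.reset_mem_state
        || (cache_state.num_tokens_to_trim > 0 && cache_state.has_linear())) {
        model_state_cleared = true;           // stands in for request.reset_state()
        cache_state.reset_mem_state = false;  // clear bookkeeping after the reset
        cache_state.num_tokens_to_trim = 0;
        return true;
    }
    return false;
}
```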
I converted the PR to a draft since it is labeled as WIP.
src/cpp/src/utils.cpp
```cpp
CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model) {
    // "ReadValue" node is cache representation in stateful model
    const std::string state_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name);
    CacheTypes cache_types;

    for (const auto op : model->get_ops()) {
        // check input size, as in LoRA adapters case it could be 0
        if (op->get_type_name() != state_node_type_name || op->get_input_size() < 1) {
            continue;
        }

        // Shape example: [-1,4,0,64]
        auto shape = op->get_input_partial_shape(0);
        const auto rank = shape.rank().get_length();
        size_t dynamic_axis_count = 0, zero_axis_count = 0;
        for (size_t i = 0; i < rank; i++) {
            if (shape[i].is_dynamic()) {
```
get_cache_types() calls shape.rank().get_length() unconditionally. If a ReadValue input has dynamic rank, get_length() can throw; this would make cache-type detection fail at runtime for some models. Guard with shape.rank().is_dynamic() (skip/continue or handle) before calling get_length(), and similarly avoid iterating dimensions when rank is dynamic.
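The suggested guard can be illustrated with a small stand-in where a dynamic rank is modeled as an empty optional, mirroring how `shape.rank().is_static()` would be consulted before `get_length()`. This emulation and the function name are assumptions for illustration only.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// nullopt plays the role of a dynamic rank (rank().is_static() == false).
using Rank = std::optional<int64_t>;

// Counts states whose rank is safe to inspect, skipping dynamic-rank ones
// instead of calling get_length() on them (which would throw in OpenVINO).
inline std::size_t count_usable_states(const std::vector<Rank>& state_ranks) {
    std::size_t usable = 0;
    for (const auto& rank : state_ranks) {
        if (!rank.has_value()) {
            continue;  // dynamic rank: skip, do not read the length
        }
        ++usable;
    }
    return usable;
}
```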
src/cpp/src/utils.cpp
```cpp
CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model) {
    // "ReadValue" node is cache representation in stateful model
    const std::string state_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name);
    CacheTypes cache_types;

    for (const auto op : model->get_ops()) {
```
New cache-type detection (get_cache_types) and the linear-cache reset path in trim_kv_cache() introduce non-trivial behavior that can regress chat/history correctness. There are existing gtests for utils (e.g., tests/cpp/utils.cpp), but no coverage for these new paths; please add unit tests covering KV-only, linear-only, and hybrid detection and verifying reset/trim bookkeeping.
```cpp
// PA backend does not support linear attention states (conv/SSM caches).
if (attention_backend == PA_BACKEND
    && utils::has_linear_attention_states(models_dir, properties)) {
    if (utils::explicitly_requires_paged_attention(user_properties)
        || user_properties.find("ATTENTION_BACKEND") != user_properties.end()) {
        GENAI_WARN("PA backend does not support models with linear attention states. The model may work incorrectly.");
    } else {
        attention_backend = SDPA_BACKEND;
    }
}
```
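The fallback decision above reduces to a small pure function, sketched below with hypothetical enum and parameter names: keep PA only when the model has no linear states or the user explicitly forced the PA backend (in which case a warning is logged), otherwise fall back to SDPA.

```cpp
// Sketch of the backend-selection rule quoted above; names are assumptions.
enum class Backend { PA, SDPA };

inline Backend select_backend(Backend requested,
                              bool has_linear_states,
                              bool user_forced_pa) {
    if (requested == Backend::PA && has_linear_states && !user_forced_pa) {
        return Backend::SDPA;  // silent fallback: PA cannot hold linear states
    }
    return requested;          // honor the explicit choice; caller warns if forced
}
```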
has_linear_attention_states(models_dir, properties) loads the language model to inspect its states, but VLMPipelineImpl(models_dir, ...) will also read/compile the same language model. Consider reusing the already-loaded language_model from VLMPipelineImpl (or reading it once in this constructor and passing it down) to avoid duplicated model reads/parsing at initialization.
src/cpp/src/utils.cpp
```cpp
for (const auto op : model.get_ops()) {
    // check input size, as in LoRA adapters case it could be 0
    if (op->get_type_name() != state_node_type_name || op->get_input_size() < 1) {
        continue;
```
In get_cache_types(), the loop uses for (const auto op : model.get_ops()), which copies each shared_ptr and bumps the atomic ref-count for every op. Using const auto& op avoids that overhead (and is more consistent with performance-sensitive model graph walks).
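The overhead is observable directly: copying a shared_ptr in the loop variable bumps its use_count inside the loop body, while binding by reference does not. The two helper names below are made up for this demonstration.

```cpp
#include <algorithm>
#include <memory>
#include <vector>

// `for (const auto op : ops)` copies each shared_ptr: inside the loop the
// use_count is one higher than the container's own reference.
inline long max_use_count_by_value(const std::vector<std::shared_ptr<int>>& ops) {
    long max_count = 0;
    for (const auto op : ops) {   // copy: atomic ref-count increment per element
        max_count = std::max<long>(max_count, op.use_count());
    }
    return max_count;
}

// `const auto&` binds to the stored shared_ptr: no ref-count traffic at all.
inline long max_use_count_by_ref(const std::vector<std::shared_ptr<int>>& ops) {
    long max_count = 0;
    for (const auto& op : ops) {  // reference: use_count stays unchanged
        max_count = std::max<long>(max_count, op.use_count());
    }
    return max_count;
}
```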
Co-authored-by: Vladimir Zlobin <vladimir.zlobin@intel.com>
Pull request overview
Copilot reviewed 28 out of 28 changed files in this pull request and generated 3 comments.
```cpp
// Shape example: [-1,4,0,64]
auto shape = op->get_input_partial_shape(0);
const auto rank = shape.rank().get_length();
size_t dynamic_axis_count = 0, zero_axis_count = 0;
for (size_t i = 0; i < rank; i++) {
    if (shape[i].is_dynamic()) {
```
get_cache_types() calls shape.rank().get_length() without checking shape.rank().is_static(). If a model contains a ReadValue with dynamic rank, this will throw/assert inside OpenVINO. Add a guard (e.g., if (!shape.rank().is_static()) continue;) before using get_length() and iterating the rank.
```cpp
// Shape example: [-1,4,0,64]
auto shape = op->get_input_partial_shape(0);
if (shape.rank().get_length() != 4) {
    // kv cache should have 4 dimensions
    continue;
```
get_kv_axes_pos() uses shape.rank().get_length() in the != 4 check without verifying that the rank is static. If rank is dynamic, get_length() can throw/assert. Consider checking shape.rank().is_static() first and skipping ReadValue nodes with dynamic rank.
```diff
  // get reflection of tokens contained in the kv cache
- utils::KVCacheState& get_kv_cache_state();
+ utils::CacheState& get_kv_cache_state();
```
The comment says this returns a reflection of tokens contained in the KV cache, but the type was changed to utils::CacheState and now covers non-KV cache kinds (e.g., linear/SSM state) as well. Consider updating the comment (and possibly the accessor name) so it matches the new semantics.
```cpp
bool has_linear_attention_states(const std::filesystem::path& models_path, const ov::AnyMap& properties) {
    return get_cache_types(*read_model(models_path, properties)).has_linear();
```
Used in VLM constructor.
Model-based constructors are needed in this case for VLMPipelineImpl and VLMContinuousBatchingAdapter; model reading is heavy.
```python
from dataclasses import dataclass
from pathlib import Path
from typing import Type
import subprocess
```
import subprocess will be flagged by Bandit (B404) in this repo (Bandit runs recursively without excluding tests/). Other test files suppress this with # nosec B404 on the import. Consider adding the same suppression here (or refactoring to reuse the existing helper that already carries the suppression) to avoid CI failures.
```diff
- import subprocess
+ import subprocess  # nosec B404
```
Description
Support fixed-size cache state for linear/hybrid attention models.
- Core abstraction: CacheTypes and CacheState (utils.hpp, utils.cpp)
- Stateful LLM pipeline (pipeline_stateful.cpp, lm_encoding.cpp)
- VLM pipeline (pipeline.cpp, inputs_embedder.cpp)
- Speculative decoding (fast_draft_strategy.cpp, eagle3_strategy.cpp)
Tests
CI
CVS-181414
Checklist: